# -*- coding: utf-8 -*-


import numpy as np
import matplotlib.pyplot as plt
import os 
from scipy.linalg import fractional_matrix_power

# Values to investigate
kl_penalty = 1
sigma = 0.03

# Retrieve the error bound and testing error for each prior-data fraction
data_for_prior_training = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
os.chdir("your directory")
if kl_penalty == 1:
    # Force a float so the filename below spells this value as "1.0", not "1".
    kl_penalty = 1.0
bound_temp_list = []
testing_error_temp_list = []
for j in data_for_prior_training:
    # Load each result file once; index 0 is read as the error bound and
    # index 1 as the testing error.
    results = np.load(str(kl_penalty) + '_' + str(sigma) + '_' + str(j) + '_fcn_.npy',
                      allow_pickle=True)
    bound_temp_list.append(results[0])
    testing_error_temp_list.append(results[1])
                
                
# Values to investigate
sigma = 0.03
kl_penalty = 1
data_for_prior_training = 0.5
sample_per_class = 325
NTK_methods = "ntk_init_withdivnothing"


# Retrieve the labels used to build the label-agreement matrix
os.chdir("C:/Users/chunr/Desktop/wei history/20210915数据/不同方法计算出gram matrix/" + NTK_methods)
# Element [-2] of the saved array is indexed below as a 2-D (samples, columns)
# array with at least 10 columns -- presumably one-hot labels; TODO confirm.
rr = np.load(str(sigma) + '_' + str(sample_per_class) + '_' + str(data_for_prior_training) + '_fcn_.npy',
             allow_pickle=True)[-2].numpy()

# final_matrix[i, j] = sum over the first 10 label columns of
# (+1 if rr[i, c] == rr[j, c] else -1).  For {0, 1}-valued labels this equals
# Y @ Y.T with Y = 2 * rr - 1.  Broadcasting the pairwise comparison per column
# replaces the original O(n^2) Python loop with identical integer results.
num_label_columns = 10
final_matrix = np.zeros((len(rr), len(rr)), dtype=int)
for ww in range(num_label_columns):
    column = rr[:, ww]
    final_matrix += np.where(column[:, None] == column[None, :], 1, -1)


# Prior-data fractions whose saved NTK Gram matrices are scored
prior_data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# Ridge term lambda / sigma_0 added to the kernel diagonal
ridge = kl_penalty / sigma
# The working directory never changes inside the loop, so set it once.
os.chdir("your directory")
temp_list = []
for i in prior_data:
    # Element [-1] of each saved array is treated as the square NTK Gram matrix.
    temp_ntk_matrix = np.load(str(sigma) + '_' + str(sample_per_class) + '_' + str(i) + '_fcn_.npy',
                              allow_pickle=True)[-1].numpy()
    # (K + ridge * I)^(-2), computed once and shared by both terms below.
    temp_ntk_matrix_with_element = fractional_matrix_power(
        temp_ntk_matrix + np.identity(len(temp_ntk_matrix)) * ridge, -2)
    # Quadratic form Y^T (K + ridge*I)^(-2) Y = sum_ij F_ij * A_ij with
    # F = final_matrix (symmetric).  The previous np.trace(F * A) only read the
    # diagonal of F, which is constant, so it ignored the labels entirely.
    quad_form = np.sum(final_matrix * temp_ntk_matrix_with_element)
    r_temp_align_value = np.sqrt(quad_form / sample_per_class) * ridge
    l_temp_align_value = quad_form * (1 / (sigma * sample_per_class))
    temp_list.append(l_temp_align_value + r_temp_align_value)
    
from matplotlib.pyplot import figure
import scipy.stats as stats

# One color / legend label per prior-data fraction (zip below truncates to 9 points).
colors = np.array(["red", "green", "black", "orange", "purple", "lime", "cyan", "magenta", 'navy'])
labels = np.array(["10% prior data", "20% prior data", "30% prior data", "40% prior data",
                   "50% prior data", "60% prior data", "70% prior data", "80% prior data",
                   "90% prior data"])

# Kendall rank correlation between the alignment values and the error bounds.
tau, p_value = stats.kendalltau(temp_list, bound_temp_list)
print("Kendall tau:", tau, "p-value:", p_value)

for align_value, bound_value, color, label in zip(temp_list, bound_temp_list, colors, labels):
    plt.scatter(align_value, bound_value, c=color, label=label)
plt.ylabel('Error Bound', fontsize=18)
plt.xlabel(r'$\mathcal{PA}$', fontsize=18)
plt.grid(linestyle='-')
# Create the legend once so the requested placement is kept; a second bare
# plt.legend() call would replace it with default placement.
legend = plt.legend(loc='upper left', borderaxespad=0.)
legend.get_frame().set_edgecolor('black')
# PA = (1/sigma_0) Y^T (k(X,X) + lambda/sigma_0 I)^{-2} Y
#      + (lambda/sigma_0) sqrt(Y^T (k(X,X) + lambda/sigma_0 I)^{-2} Y)
plt.title(r"Correlation between $\mathcal{PA}$ and bound"'\n' r"under FCN with different prior", fontsize=18)
plt.show()